{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import math\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pickle.load(open('quant_data_old.pkl', 'rb'))\n",
    "type(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for layer_name, layer_data in data.items():\n",
    "    print(f'layer_name: {layer_name}')\n",
    "    for data_name in ['output', 'weight']:\n",
    "        if data_name in layer_data:\n",
    "            scale_bit = math.log(layer_data[data_name]['scale'], 0.5)\n",
    "            assert abs(scale_bit - round(scale_bit)) < 1e-6\n",
    "            print(f\"{data_name} scale_bit: {round(scale_bit)}, zero_point: {layer_data[data_name]['zero_point']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inp = data['quant']['output']\n",
    "inp_data = torch.tensor(inp['data'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kernel = data['conv.0.conv']['weight']\n",
    "kernel_data = torch.tensor(kernel['data'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = data['conv.0.conv']['output']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "o = F.conv2d((inp_data-64), kernel_data, stride=2, padding=0)\n",
    "o = F.relu(o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.round(o / 2**10).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out['data']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data['fc.2']['output']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(data['quant']['output']['data'][0].transpose([1,2,0]).astype(np.uint8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "## export some horse images\n",
    "import os\n",
    "os.chdir('../')\n",
    "from PIL import Image\n",
    "import torchvision\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n"
     ]
    }
   ],
   "source": [
    "CIFAR_ROOT = '/media/HD1/Datasets/cifar'\n",
    "trainset = torchvision.datasets.CIFAR10(\n",
    "    root=CIFAR_ROOT, train=True, download=False)\n",
    "print(trainset.classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "out_dir = 'data/horse_samples'\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "count = 0\n",
    "for img, label in trainset:\n",
    "    if label == trainset.class_to_idx['horse']:\n",
    "        img.save(f'{out_dir}/{count:04d}.png')\n",
    "        count += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "['airplane',\n",
       " 'automobile',\n",
       " 'bird',\n",
       " 'cat',\n",
       " 'deer',\n",
       " 'dog',\n",
       " 'frog',\n",
       " 'horse',\n",
       " 'ship',\n",
       " 'truck']"
      ]
     },
     "metadata": {},
     "execution_count": 9
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}